# coding=utf-8
# Copyright 2020 The Tensor2Robot Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""T2RModels for duck pose prediction task."""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function
from typing import Optional, Tuple, Text
from absl import logging
import gin
import numpy as np
from six.moves import range
from tensor2robot.layers import vision_layers
from tensor2robot.models import abstract_model
from tensor2robot.models import critic_model
from tensor2robot.models import regression_model
from tensor2robot.preprocessors import abstract_preprocessor
from tensor2robot.research.dql_grasping_lib import tf_modules
from tensor2robot.utils import tensorspec_utils
import tensorflow.compat.v1 as tf  # tf
from tensorflow.contrib import framework
from tensorflow.contrib import layers

TensorSpec = tensorspec_utils.ExtendedTensorSpec  # pylint: disable=invalid-name

TRAIN = tf.estimator.ModeKeys.TRAIN
PREDICT = tf.estimator.ModeKeys.PREDICT


class DefaultPoseEnvContinuousPreprocessor(
    abstract_preprocessor.AbstractPreprocessor):
  """The default pose env preprocessor.

  This preprocessor simply converts all images from uint8 to float32.
  """

  def get_in_feature_specification(self, mode):
    """See base class."""
    model_spec = self._model_feature_specification_fn(mode)
    in_spec = tensorspec_utils.TensorSpecStruct()
    # Incoming images are serialized as uint8; everything except the dtype is
    # taken verbatim from the model's own feature spec.
    in_spec['state/image'] = TensorSpec(
        shape=model_spec.state.image.shape,
        dtype=tf.uint8,
        name=model_spec.state.image.name,
        data_format=model_spec.state.image.data_format)
    in_spec['action/pose'] = model_spec.action.pose
    return in_spec

  def get_in_label_specification(self, mode):
    """See base class."""
    return tensorspec_utils.flatten_spec_structure(
        self._model_label_specification_fn(mode))

  def get_out_feature_specification(self, mode):
    """See base class."""
    return tensorspec_utils.flatten_spec_structure(
        self._model_feature_specification_fn(mode))

  def get_out_label_specification(self, mode):
    """See base class."""
    return tensorspec_utils.flatten_spec_structure(
        self._model_label_specification_fn(mode))

  def _preprocess_fn(self, features, labels, mode):
    """Casts the uint8 state image to float32; labels pass through unchanged."""
    image = features.state.image
    features.state.image = tf.image.convert_image_dtype(image, tf.float32)
    return features, labels


@gin.configurable
class PoseEnvContinuousMCModel(critic_model.CriticModel):
  """Continuous MC model for Pose Env."""

  def get_action_specification(self):
    """Returns the action spec: a 2-D continuous pose vector."""
    pose_spec = TensorSpec(shape=(2,), dtype=tf.float32, name='pose')
    return tensorspec_utils.TensorSpecStruct(
        pose=pose_spec)

  def get_state_specification(self):
    """Returns the state spec: a 64x64x3 float image observation."""
    image_spec = TensorSpec(
        shape=(64, 64, 3), dtype=tf.float32, name='state/image',
        data_format='jpeg')
    return tensorspec_utils.TensorSpecStruct(
        image=image_spec)

  @property
  def default_preprocessor_cls(self):
    """Preprocessor class used when no preprocessor is supplied explicitly."""
    return DefaultPoseEnvContinuousPreprocessor

  def get_label_specification(self, mode):
    """Returns the label spec: a scalar per-example reward."""
    del mode
    reward_spec = TensorSpec(shape=(), dtype=tf.float32, name='reward')
    return tensorspec_utils.TensorSpecStruct(reward=reward_spec)

  def _q_features(self, state, action, is_training=True, reuse=tf.AUTO_REUSE):
    """Compute feature representation of state and action.

    This method is re-used for computing partial context features for RL^2
    algorithm.

    Args:
      state: Image observation tensor.
      action: Continuous-valued 1-D action tensor.
      is_training: Bool specifying whether this graph is being constructed for
        training or not.
      reuse: Bool or TF re-use ENUM.

    Returns:
      Flattened state-action feature tensor of shape (num_batch, feature_size)
        where feature_size is h * w * 32.
    """
    # State-action embedding module.
    net = state
    channels = 32
    with tf.variable_scope('q_features', reuse=reuse):
      with framework.arg_scope(tf_modules.argscope(is_training=is_training)):
        for layer_index in range(3):
          net = layers.conv2d(net, channels, kernel_size=3)
          logging.info('conv%d %s', layer_index, net.get_shape())
      # Project the action into the same channel dimension so it can be added
      # elementwise to the conv feature map below.
      action_context = layers.fully_connected(action, channels)
      _, h, w, _ = net.shape.as_list()
      num_batch_net = tf.shape(net)[0]
      num_batch_context = tf.shape(action_context)[0]
      # assume num_batch_context >= num_batch_net
      # Tile state features so each state is paired with its group of actions
      # (num_batch_context is assumed to be a multiple of num_batch_net).
      net = tf.tile(net, [num_batch_context // num_batch_net, 1, 1, 1])
      # Broadcast each action embedding across all h x w spatial positions.
      action_context = tf.reshape(action_context,
                                  [num_batch_context, 1, 1, channels])
      action_context = tf.tile(action_context, [1, h, w, 1])
      net += action_context
      net = tf.layers.flatten(net)
    return net

  def q_func(self,
             features,
             scope,
             mode,
             config = None,
             params = None,
             reuse=tf.AUTO_REUSE):
    """Builds the Q(state, action) critic network.

    Args:
      features: Spec structure holding `state.image` and `action.pose` tensors.
      scope: String specifying variable scope.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: Unused optional configuration object.
      params: Unused optional dict of hyper parameters.
      reuse: Whether or not to reuse variables under variable scope 'scope'.

    Returns:
      Dict with key 'q_predicted' mapping to a rank-1 tensor of Q values.
    """
    del params
    is_training = mode == TRAIN
    with tf.variable_scope(scope, reuse=reuse, use_resource=True):
      image = tf.image.convert_image_dtype(features.state.image, tf.float32)
      net = self._q_features(
          image, features.action.pose, is_training=is_training, reuse=reuse)
      net = layers.stack(net, layers.fully_connected, [100, 100])
      # Final linear head: no activation or normalization so the predicted Q
      # value is unbounded.
      net = layers.fully_connected(
          net,
          num_outputs=1,
          normalizer_fn=None,
          weights_regularizer=None,
          activation_fn=None)
      return {'q_predicted': tf.squeeze(net, 1)}

  def pack_features(self, state, context, timestep, actions):
    """Packs one observation plus candidate actions into a spec structure."""
    del context, timestep
    # Add a leading batch dimension to the single state observation.
    batch_obs = np.expand_dims(state, 0)
    return tensorspec_utils.TensorSpecStruct(
        state=batch_obs, action=actions)


class DefaultPoseEnvRegressionPreprocessor(
    abstract_preprocessor.AbstractPreprocessor):
  """The default pose env preprocessor.

  This preprocessor simply converts all images from uint8 to float32.
  """

  def get_in_feature_specification(self, mode):
    """See base class."""
    model_spec = self._model_feature_specification_fn(mode)
    in_spec = tensorspec_utils.TensorSpecStruct()
    # The serialized state image arrives as uint8; all other attributes are
    # carried over from the model's feature spec unchanged.
    in_spec['state'] = TensorSpec(
        shape=model_spec.state.shape,
        dtype=tf.uint8,
        name=model_spec.state.name,
        data_format=model_spec.state.data_format)
    return in_spec

  def get_in_label_specification(self, mode):
    """See base class."""
    return tensorspec_utils.flatten_spec_structure(
        self._model_label_specification_fn(mode))

  def get_out_feature_specification(self, mode):
    """See base class."""
    return tensorspec_utils.flatten_spec_structure(
        self._model_feature_specification_fn(mode))

  def get_out_label_specification(self, mode):
    """See base class."""
    return tensorspec_utils.flatten_spec_structure(
        self._model_label_specification_fn(mode))

  def _preprocess_fn(self, features, labels, mode):
    """Casts the uint8 state image to float32; labels pass through unchanged."""
    state = features.state
    features.state = tf.image.convert_image_dtype(state, tf.float32)
    return features, labels


@gin.configurable
class PoseEnvRegressionModel(regression_model.RegressionModel):
  """Continuous regression output model for Pose Env."""

  @property
  def default_preprocessor_cls(self):
    """Preprocessor class used when no preprocessor is supplied explicitly."""
    return DefaultPoseEnvRegressionPreprocessor

  def get_feature_specification(self, mode):
    """Returns the feature spec: a single 64x64x3 float image observation."""
    del mode
    state_spec = TensorSpec(
        shape=(64, 64, 3), dtype=tf.float32, name='state/image',
        data_format='jpeg')
    return tensorspec_utils.TensorSpecStruct(state=state_spec)

  def get_label_specification(self, mode):
    """Returns the label spec: target pose plus a per-example reward weight."""
    del mode
    # Fix: the shapes were written as parenthesized ints, e.g.
    # `shape=(self._action_size)`, which is NOT a tuple and only worked via
    # TensorShape's scalar coercion. Use explicit rank-1 tuples, consistent
    # with the other specs in this file.
    target_spec = TensorSpec(
        shape=(self._action_size,), dtype=tf.float32, name='target_pose')
    reward_spec = TensorSpec(shape=(1,), dtype=tf.float32, name='reward')
    return tensorspec_utils.TensorSpecStruct(
        target_pose=target_spec, reward=reward_spec)

  def pack_features(self, state, context, timestep):
    """Packs one observation (with leading batch dim) into a spec structure."""
    del context, timestep
    batch_obs = np.expand_dims(state, 0)
    return tensorspec_utils.TensorSpecStruct(state=batch_obs)

  @property
  def action_size(self):
    """Dimensionality of the predicted pose vector."""
    return self._action_size

  def get_config(self):
    """This model trains fairly quickly so evaluate frequently."""
    return tf.estimator.RunConfig(
        save_checkpoints_steps=2000,
        keep_checkpoint_max=5)

  def a_func(
      self,
      features,
      scope,
      mode,
      config = None,
      params = None,
      reuse=tf.AUTO_REUSE,
      context_fn=None,
  ):
    """A (state) regression function.

    This function can return a stochastic or a deterministic tensor.

    Args:
      features: This is the first item returned from the input_fn and parsed by
        tensorspec_utils.validate_and_pack. A spec_structure which fulfills the
        requirements of the self.get_feature_specification.
      scope: String specifying variable scope.
      mode: (ModeKeys) Specifies if this is training, evaluation or prediction.
      config: Optional configuration object. Will receive what is passed to
        Estimator in config parameter, or the default config. Allows updating
        things in your model_fn based on configuration such as num_ps_replicas,
        or model_dir.
      params: An optional dict of hyper parameters that will be passed into
        input_fn and model_fn. Keys are names of parameters, values are basic
        python types. There are reserved keys for TPUEstimator, including
        'batch_size'.
      reuse: Whether or not to reuse variables under variable scope 'scope'.
      context_fn: Optional python function that takes in features and returns
        new features of same shape. For merging information like in RL^2.

    Returns:
      outputs: A {key: Tensor} mapping. The key 'action' is required.
    """
    del config, params
    is_training = mode == TRAIN
    image = tf.image.convert_image_dtype(features.state, tf.float32)
    with tf.variable_scope(scope, reuse=reuse, use_resource=True):
      with tf.variable_scope('state_features', reuse=reuse, use_resource=True):
        # end_points are not needed; discard them directly instead of binding
        # a name and del-ing it.
        feature_points, _ = vision_layers.BuildImagesToFeaturesModel(
            image,
            is_training=is_training,
            normalizer_fn=layers.layer_norm)
      if context_fn:
        feature_points = context_fn(feature_points)
      estimated_pose, _ = vision_layers.BuildImageFeaturesToPoseModel(
          feature_points, num_outputs=self._action_size)
    return {'inference_output': estimated_pose,
            'state_features': feature_points}

  def loss_fn(self, labels, inference_outputs, mode, params=None):
    """Reward-weighted mean squared error against the target pose."""
    del mode, params
    return tf.losses.mean_squared_error(
        labels=labels.target_pose,
        predictions=inference_outputs['inference_output'],
        weights=labels.reward)
